{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.02186339, -0.00224868, -0.04336443, -0.00508288], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, _, info = self.env.step(action)\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "MyWrapper().reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "WTHPMs_s7YFc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv at 0x7f072dfdb340>,\n",
       " <Monitor<MyWrapper<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "\n",
    "#创建训练环境和测试环境\n",
    "env_train = make_vec_env(MyWrapper, n_envs=4)\n",
    "env_test = Monitor(MyWrapper())\n",
    "\n",
    "env_train, env_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31fb9dc4a552416b9a1800343da49e0c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "71.1597838528824"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "\n",
    "#测试超参数\n",
    "def test_params(params):\n",
    "    #定义一个模型\n",
    "    model = PPO(\n",
    "        policy='MlpPolicy',\n",
    "        env=env_train,\n",
    "        n_steps=1024,\n",
    "        batch_size=64,\n",
    "        #取超参数\n",
    "        n_epochs=params['n_epochs'],\n",
    "        #取超参数\n",
    "        gamma=params['gamma'],\n",
    "        gae_lambda=0.98,\n",
    "        ent_coef=0.01,\n",
    "        verbose=0,\n",
    "    )\n",
    "\n",
    "    #训练\n",
    "    #取超参数\n",
    "    model.learn(total_timesteps=params['total_timesteps'], progress_bar=True)\n",
    "\n",
    "    #测试\n",
    "    mean_reward, std_reward = evaluate_policy(model,\n",
    "                                              env_test,\n",
    "                                              n_eval_episodes=50,\n",
    "                                              deterministic=True)\n",
    "\n",
    "    #最终的分数就是简单的求差,这也是study要优化的数\n",
    "    score = mean_reward - std_reward\n",
    "\n",
    "    return score\n",
    "\n",
    "\n",
    "test_params({'n_epochs': 2, 'gamma': 0.99, 'total_timesteps': 500})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-01-17 17:10:12,768]\u001b[0m A new study created in memory with name: PPO-LunarLander-v2\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "392a76f861194caa94604c5f6b26f89f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3964/3720689107.py:15: FutureWarning: suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use :func:`~optuna.trial.Trial.suggest_float` instead.\n",
      "  'gamma': trial.suggest_uniform('gamma', 0.99, 0.9999),\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-01-17 17:10:26,030]\u001b[0m Trial 0 finished with value: 102.40188282816155 and parameters: {'n_epochs': 5, 'gamma': 0.9963042134418639, 'total_timesteps': 580}. Best is trial 0 with value: 102.40188282816155.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f6718a328d0489e8eb64ecfc51e3ea7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-01-17 17:10:39,938]\u001b[0m Trial 1 finished with value: 76.35171294486264 and parameters: {'n_epochs': 3, 'gamma': 0.9901653456845795, 'total_timesteps': 1082}. Best is trial 0 with value: 102.40188282816155.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2aa1b314c8c54d81be17e4f8714b0725",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-01-17 17:10:53,935]\u001b[0m Trial 2 finished with value: 100.10383872771487 and parameters: {'n_epochs': 5, 'gamma': 0.9953326653547164, 'total_timesteps': 1338}. Best is trial 0 with value: 102.40188282816155.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2661f05a37ec49de8bd18c8cc0b84c15",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-01-17 17:11:00,310]\u001b[0m Trial 3 finished with value: 88.37066386696605 and parameters: {'n_epochs': 3, 'gamma': 0.9988387230407516, 'total_timesteps': 1090}. Best is trial 0 with value: 102.40188282816155.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be4fb6be9441444ca29655a8b3392292",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2023-01-17 17:11:09,927]\u001b[0m Trial 4 finished with value: 29.38088676826756 and parameters: {'n_epochs': 3, 'gamma': 0.9970100515891315, 'total_timesteps': 780}. Best is trial 0 with value: 102.40188282816155.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([102.40188282816155],\n",
       " {'n_epochs': 5, 'gamma': 0.9963042134418639, 'total_timesteps': 580})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import optuna\n",
    "from optuna.samplers import TPESampler\n",
    "\n",
    "#定义一个超参数学习器\n",
    "study = optuna.create_study(sampler=TPESampler(),\n",
    "                            study_name='PPO-LunarLander-v2',\n",
    "                            direction='maximize')\n",
    "\n",
    "\n",
    "#求最优超参数\n",
    "def f(trial):\n",
    "    #定义要找的超参数,并设置上下限\n",
    "    params = {\n",
    "        'n_epochs': trial.suggest_int('n_epochs', 3, 5),\n",
    "        'gamma': trial.suggest_uniform('gamma', 0.99, 0.9999),\n",
    "        'total_timesteps': trial.suggest_int('total_timesteps', 500, 2000),\n",
    "    }\n",
    "\n",
    "    #测试超参数\n",
    "    return test_params(params)\n",
    "\n",
    "\n",
    "study.optimize(f, n_trials=5)\n",
    "\n",
    "#输出最佳分数和超参数\n",
    "study.best_trial.values, study.best_trial.params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "8JVUHgljIW1E"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "940b8581b32b4b23822d3eda0a93cc22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "133.760702749365"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#用最优超参数训练一个模型\n",
    "test_params(study.best_trial.params)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Unit 1 Special Content: Optuna Guide.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
